# import math
#
# import torch
# import torch.nn.functional as F
# import networkx as nx
# from continual_rl.utils.utils import Utils

from compute_distribution import wasserstein_distance_tasks


import networkx as nx
import logging
import math
import torch
import torch.nn.functional as F
# from scipy.stats import wasserstein_distance


from compute_distribution import wasserstein_distance_tasks
# import networkx as nx
# import logging
# import math
# import torch
# import torch.nn.functional as F
import numpy as np
import time

import logging
import math
import torch



class KnowledgeDistiller:
    """Distill knowledge from similar peer nodes into an active node's policy.

    A complete ``networkx`` graph is built over ``nodes``. Each edge weight is a
    similarity derived from the Wasserstein distance between the two nodes'
    features, which are read from ``policy._node_features``. During
    :meth:`distill`, the active node's logits are pulled towards the logits of
    its top-``k`` most similar neighbours through a similarity-weighted KL loss.

    Recognised ``config`` keys:
        alpha (default 0.3): scale applied to the averaged KL loss.
        k (default 2): number of neighbours used per distillation step.
        update_freq (default 100): rebuild the graph every ``update_freq``
            calls to :meth:`distill`; a non-positive value disables rebuilds.
        wasserstein_scale (default 1.0): multiplier on the distance before
            the ``exp(-d * scale)`` similarity transform.
        log_file (default '/tmp/distill.log'): file receiving DEBUG logs.
            It is only honoured the first time the shared logger is configured.
        device (default 'cpu'): device of the zero loss returned when
            distillation is skipped.
    """

    def __init__(self, nodes, config, policy):
        self.nodes = nodes
        self.config = config or {}
        self.policy = policy
        self.logger = self._create_logger()  # shared, configured-once logger
        self.node_to_index = {}  # node unique_id -> index into self.nodes
        self._build_index_mapping()
        self.knowledge_graph = self._build_knowledge_graph()

        # Distillation hyper-parameters.
        self.alpha = self.config.get('alpha', 0.3)  # weight of the distillation loss
        self.k = self.config.get('k', 2)  # number of neighbours to distill from
        self.update_count = 0
        self.update_freq = self.config.get('update_freq', 100)  # graph rebuild period, in distill() calls

        # Record the initial state.
        self.logger.info(f"KnowledgeDistiller initialized with {len(self.nodes)} nodes")
        self.logger.info(
            f"Knowledge Graph Built: Nodes={len(self.knowledge_graph.nodes)}, Edges={len(self.knowledge_graph.edges)}")

    def _create_logger(self):
        """Return the ``'KnowledgeDistiller'`` logger, attaching handlers only once.

        ``logging.getLogger`` returns a process-wide singleton. Attaching
        handlers on every instantiation would duplicate each log line and leak
        an open log file per instance, so an already-configured logger is
        returned unchanged.

        The file handler records DEBUG and above to ``config['log_file']``
        (default ``'/tmp/distill.log'``). The console handler records INFO and
        above.
        """
        logger = logging.getLogger('KnowledgeDistiller')
        logger.setLevel(logging.DEBUG)
        if logger.handlers:
            return logger

        config = getattr(self, 'config', None) or {}
        log_file = config.get('log_file', '/tmp/distill.log')
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        return logger

    def _build_index_mapping(self):
        """Map each node's ``unique_id`` to its position in ``self.nodes``.

        When several nodes share a ``unique_id``, the first occurrence wins.
        """
        self.logger.info("Building node index mapping...")
        self.node_to_index = {}
        for idx, node in enumerate(self.nodes):
            if node.unique_id not in self.node_to_index:
                self.node_to_index[node.unique_id] = idx
                self.logger.debug(f"Mapped {node.unique_id} to index {idx}")

    def _build_knowledge_graph(self):
        """Build a complete undirected similarity graph over ``self.nodes``.

        Graph vertices are list indices into ``self.nodes``, with the node
        object stored in the vertex attribute ``node``. Every unordered pair
        gets one edge whose ``weight`` is :meth:`_compute_similarity` of the pair.

        Returns:
            networkx.Graph: the newly built graph.
        """
        self.logger.info("Building knowledge graph...")
        graph = nx.Graph()

        for idx, node in enumerate(self.nodes):
            graph.add_node(idx, node=node)
            self.logger.debug(f"Added node {node.unique_id} at index {idx}")

        # Visit each unordered pair (i < j) exactly once.
        total_edges = 0
        num_nodes = len(self.nodes)
        for i in range(num_nodes):
            node_i = self.nodes[i]
            for j in range(i + 1, num_nodes):
                node_j = self.nodes[j]
                sim = self._compute_similarity(node_i, node_j)
                graph.add_edge(i, j, weight=sim)
                total_edges += 1
                self.logger.debug(
                    f"Added edge between {node_i.unique_id}({i}) and {node_j.unique_id}({j}) with similarity {sim:.4f}")

        self.logger.info(f"Built knowledge graph with {total_edges} edges")
        return graph

    def _get_device(self):
        """Return the device for the zero loss returned when distillation is skipped.

        Defaults to CPU; ``config['device']`` overrides it. The device is not
        auto-detected from the node models.
        """
        return torch.device(self.config.get('device', 'cpu'))

    def _compute_similarity(self, node_i, node_j):
        """Return the feature similarity of two nodes as ``exp(-W * scale)``.

        ``W`` is ``wasserstein_distance_tasks`` applied to the two nodes'
        entries in ``self.policy._node_features``. ``scale`` is
        ``config['wasserstein_scale']`` (default 1.0).

        A neutral 0.5 is returned instead in three cases: either node has no
        features, the distance computation raises, or the result is not finite.
        """
        if (not hasattr(self.policy, '_node_features') or
                node_i.unique_id not in self.policy._node_features or
                node_j.unique_id not in self.policy._node_features):
            self.logger.warning(f"Node features missing for {node_i.unique_id} or {node_j.unique_id}")
            return 0.5  # neutral default similarity

        features_i = self.policy._node_features[node_i.unique_id]
        features_j = self.policy._node_features[node_j.unique_id]

        try:
            w_dist = wasserstein_distance_tasks(features_i, features_j)

            # Smaller distance -> higher similarity; exp(-x) maps x >= 0 into (0, 1].
            similarity = math.exp(-w_dist * self.config.get('wasserstein_scale', 1.0))

            # exp() is never negative, so only NaN / inf need guarding.
            if not math.isfinite(similarity):
                similarity = 0.5

            self.logger.debug(
                f"Similarity between {node_i.unique_id} and {node_j.unique_id}: {similarity:.4f} (dist={w_dist:.4f})")
            return similarity
        except Exception as e:
            # Best effort: a failed pair must not prevent the graph from being built.
            self.logger.error(f"Error computing similarity: {str(e)}")
            return 0.5

    def update_graph(self):
        """Rebuild the index mapping and the knowledge graph from ``self.nodes``."""
        self.logger.info("Updating knowledge graph...")
        self._build_index_mapping()  # nodes may have changed since the last build
        self.knowledge_graph = self._build_knowledge_graph()
        self.logger.info("Knowledge graph updated")

    def get_neighbors(self, node):
        """Return the ``self.k`` neighbours of ``node`` with the highest edge weight.

        Returns:
            list of ``(neighbor_index, {'weight': float, 'node': node_obj})``
            sorted by descending weight. An empty list is returned if the node
            is unknown or an error occurs.
        """
        try:
            if node.unique_id not in self.node_to_index:
                self.logger.error(f"Node {node.unique_id} not found in node_to_index mapping")
                return []

            node_index = self.node_to_index[node.unique_id]

            if node_index not in self.knowledge_graph:
                self.logger.error(f"Node index {node_index} not found in graph")
                return []

            # Collect (index, attributes) for every adjacent vertex.
            neighbors = []
            for neighbor_idx, attrs in self.knowledge_graph[node_index].items():
                neighbor_node = self.nodes[neighbor_idx]
                weight = attrs.get('weight', 0.0)
                neighbors.append((neighbor_idx, {'weight': weight, 'node': neighbor_node}))

            # Strongest similarity first, then keep only the top k.
            neighbors.sort(key=lambda x: x[1]['weight'], reverse=True)
            neighbors = neighbors[:self.k]

            neighbor_names = [f"{attrs['node'].unique_id}({attrs['weight']:.2f})" for _, attrs in neighbors]
            self.logger.info(f"Found {len(neighbors)} neighbors for {node.unique_id}: {', '.join(neighbor_names)}")

            return neighbors
        except Exception as e:
            self.logger.error(f"Error getting neighbors: {str(e)}")
            return []

    def should_update_graph(self):
        """Count one call and report whether the graph is due for a rebuild.

        Returns True on every ``update_freq``-th call. A non-positive
        ``update_freq`` disables periodic rebuilds instead of dividing by zero.
        """
        self.update_count += 1
        if self.update_freq <= 0:
            return False
        if self.update_count % self.update_freq == 0:
            self.logger.info(f"Update count {self.update_count} reached, triggering graph update")
            return True
        return False

    def safe_kl_div(self, active_logits, neighbor_logits, temperature=1.0):
        """Return the temperature-scaled KL(neighbor || active) over the last dim.

        The neighbour distribution is detached, so no gradient reaches the
        neighbour. Its probabilities are clamped to ``[1e-7, 1]`` for numerical
        stability. The result is multiplied by ``temperature ** 2`` to keep
        gradient magnitudes comparable across temperatures.
        """
        soft_logits1 = F.log_softmax(active_logits / temperature, dim=-1)
        soft_logits2 = F.softmax(neighbor_logits / temperature, dim=-1)

        soft_logits2 = torch.clamp(soft_logits2, 1e-7, 1)

        kl = F.kl_div(
            soft_logits1,
            soft_logits2.detach(),  # do not propagate gradients into the neighbour
            reduction='batchmean'
        ) * (temperature ** 2)  # compensate for temperature scaling

        self.logger.debug(f"KL loss value: {kl.item():.6f}")
        return kl

    def distill(self, active_node, batch, task_features):
        """Compute the similarity-weighted KL distillation loss for ``active_node``.

        Args:
            active_node: node being trained. It must expose ``unique_id`` and
                ``policy_forward(batch, action_space_id)``.
            batch: passed unchanged to every ``policy_forward`` call.
            task_features: sized collection. ``len(task_features) - 1`` is used
                as the action-space id, clamped to 0 when it is empty.

        Returns:
            torch.Tensor: ``alpha`` times the mean of ``weight * KL`` over the
            neighbours that produced a finite, positive KL. A zero tensor on
            :meth:`_get_device` is returned when distillation is skipped. The
            loss is differentiable only if the active node's logits require grad.
        """
        task_id = int(len(task_features) - 1)
        if task_id < 0:
            # A negative id would silently select the last action space.
            self.logger.warning("task_features is empty; falling back to action_space_id 0")
            task_id = 0

        self.logger.info(f"Using task_id as action_space_id: {task_id}")
        action_space_id = task_id

        device = self._get_device()
        self.logger.info(f"Distillation started for node {active_node.unique_id} on device {device}")

        if self.should_update_graph():
            self.update_graph()

        # Distillation needs at least two nodes.
        if self.knowledge_graph.number_of_nodes() < 2:
            self.logger.warning("Knowledge graph has fewer than 2 nodes, skipping distillation")
            return torch.tensor(0.0).to(device)

        neighbors = self.get_neighbors(active_node)
        if not neighbors:
            self.logger.warning(f"No neighbors found for node {active_node.unique_id}")
            return torch.tensor(0.0).to(device)

        try:
            active_logits = active_node.policy_forward(batch, action_space_id)
            self.logger.debug(f"Active logits shape: {tuple(active_logits.shape)}")
            active_is_finite = bool(torch.isfinite(active_logits).all())
        except Exception as e:
            self.logger.error(f"Active node forward pass failed: {str(e)}", exc_info=True)
            return torch.tensor(0.0).to(device)

        # NaN/inf active logits would turn every KL term (and the loss) into NaN.
        if not active_is_finite:
            self.logger.warning(f"Invalid values in active logits: {active_node.unique_id}")
            return torch.tensor(0.0).to(device)

        # Collect terms in a list so the sum stays on the logits' device.
        weighted_losses = []
        for _, attrs in neighbors:
            neighbor_node = attrs['node']
            weight = attrs['weight']  # edge similarity

            if weight <= 0 or math.isnan(weight):
                self.logger.debug(f"Skipping neighbor {neighbor_node.unique_id} with invalid weight {weight}")
                continue

            try:
                with torch.no_grad():  # neighbours act as fixed teachers
                    neighbor_logits = neighbor_node.policy_forward(batch, action_space_id)

                if active_logits.shape != neighbor_logits.shape:
                    self.logger.warning(
                        f"Shape mismatch: active={active_logits.shape}, neighbor={neighbor_logits.shape}")
                    continue

                if not torch.isfinite(neighbor_logits).all():
                    self.logger.warning(f"Invalid values in neighbor logits: {neighbor_node.unique_id}")
                    continue

                kl_loss = self.safe_kl_div(active_logits, neighbor_logits)
                kl_value = kl_loss.item()

                # NaN compares False against everything, so check finiteness explicitly.
                if not math.isfinite(kl_value) or kl_value <= 0:
                    self.logger.debug(f"Invalid KL loss: {kl_value} for {neighbor_node.unique_id}")
                    continue

                weighted_losses.append(weight * kl_loss)
                self.logger.debug(f"KL loss with {neighbor_node.unique_id}: {kl_value:.6f} (weight={weight:.4f})")
            except Exception as e:
                self.logger.error(f"Error processing neighbor {neighbor_node.unique_id}: {str(e)}")

        if not weighted_losses:
            self.logger.warning("No valid KL losses calculated")
            return torch.tensor(0.0).to(device)

        avg_loss = sum(weighted_losses) / len(weighted_losses)
        final_loss = self.alpha * avg_loss

        # Do not force requires_grad: that would hide a broken graph behind a detached leaf.
        if not final_loss.requires_grad:
            self.logger.warning("Distillation loss is not connected to the autograd graph")

        self.logger.info(f"Final distill loss: {final_loss.item():.6f} (from {len(weighted_losses)} neighbors)")

        return final_loss



















# class KnowledgeDistiller:
#     def __init__(self, nodes, config, policy):
#         self.nodes = nodes
#         self.config = config or {}
#         # self.knowledge_graph = self._build_knowledge_graph()
#         self.policy = policy
#         self.knowledge_graph = self._build_knowledge_graph()
#
#         # 设置蒸馏参数
#         self.alpha = self.config.get('alpha', 0.3)  # 蒸馏损失权重
#         self.k = self.config.get('k', 3)  # 邻居数量
#         self.update_count = 0
#         self.update_freq = self.config.get('update_freq', 100)  # 更新频率
#         self.dir = '/home/hhz/文档/tmp/My_sane（3）/tmp'
#
#     def _build_knowledge_graph(self):
#         """基于Wasserstein距离构建知识图谱"""
#         graph = nx.Graph()
#         for i, node_i in enumerate(self.nodes):
#             graph.add_node(i)
#             for j, node_j in enumerate(self.nodes):
#                 if i < j:  # 避免重复计算
#                     # 使用Wasserstein距离计算相似度
#                     sim = self._compute_similarity(node_i, node_j)
#                     graph.add_edge(i, j, weight=sim)
#
#         # self._logger.info(
#         #     f"Knowledge Graph Built: Nodes={len(self.knowledge_graph.nodes)}, Edges={len(self.knowledge_graph.edges)}")
#         return graph
#
#     def get_neighbors(self, node_id):
#         """获取最相关邻居节点"""
#         neighbors = []
#         if node_id in self.knowledge_graph:
#             neighbors = sorted(
#                 self.knowledge_graph[node_id].items(),
#                 key=lambda x: x[1]['weight'],
#                 reverse=True
#             )
#         return neighbors[:self.k]
#
#     def _get_device(self):
#         """获取模型设备"""
#         return next(iter(self.nodes[0].impala_trainer.actor_model.parameters())).device
#
#     def _compute_similarity(self, node_i, node_j):
#         """使用Wasserstein距离计算节点特征相似度"""
#         # 确保节点特征存在
#         if (node_i.unique_id not in self.policy._node_features or
#                 node_j.unique_id not in self.policy._node_features):
#             return 0.0  # 默认相似度
#
#         features_i = self.policy._node_features[node_i.unique_id]
#         features_j = self.policy._node_features[node_j.unique_id]
#
#         # 计算Wasserstein距离并转换为相似度
#         w_dist = wasserstein_distance_tasks(features_i, features_j)
#         # 距离越小越相似，使用指数转换确保[0,1]范围
#         similarity = math.exp(-w_dist * self.config.get('wasserstein_scale', 1.0))
#         return similarity
#
#     # def _build_knowledge_graph(self):
#     #     """动态构建知识图谱"""
#     #     graph = nx.Graph()
#     #     for i, node_i in enumerate(self.nodes):
#     #         graph.add_node(i)
#     #         for j, node_j in enumerate(self.nodes):
#     #             if i < j:  # 避免重复计算
#     #                 sim = self._compute_policy_similarity(node_i, node_j)
#     #                 graph.add_edge(i, j, weight=sim)
#     #     return graph
#
#     def update_graph(self):
#         """更新知识图谱"""
#         self.knowledge_graph = self._build_knowledge_graph()
#
#     # def get_neighbors(self, node_id):
#     #     """获取最相关邻居节点"""
#     #     neighbors = []
#     #     if node_id in self.knowledge_graph:
#     #         neighbors = sorted(
#     #             self.knowledge_graph[node_id].items(),
#     #             key=lambda x: x[1]['weight'],
#     #             reverse=True
#     #         )
#     #     return neighbors[:self.k]
#
#     def should_update_graph(self):
#         """根据更新频率决定是否更新知识图"""
#         self.update_count += 1
#         return self.update_count % self.update_freq == 0
#
#     @property
#     def _logger(self):
#         logger = Utils.create_logger(f"{self.dir}/distill.log")
#         return logger
#
#     def distill(self, active_node, batch):
#         """计算知识蒸馏损失"""
#         device = self._get_device()
#
#         # 检查是否需要更新知识图
#         if self.should_update_graph():
#             self.update_graph()
#
#         if len(self.knowledge_graph.nodes) < 2:  # 需要至少2个节点
#             return torch.tensor(0.0).to(device)
#
#             # 检查相似节点选择
#         neighbors = self.get_neighbors(active_node.id)
#         if not neighbors:  # 如果没有邻居
#             return torch.tensor(0.0).to(device)
#
#         # 添加日志输出
#         self._logger.info(f"[Distill] Active node: {active_node.unique_id}, Neighbors: {[n.unique_id for n in neighbors]}")
#         print(
#             f"[Distill] Active node: {active_node.unique_id}, Neighbors: {[n.unique_id for n in neighbors]}")
#
#         # # 获取邻居节点
#         # neighbors = self.get_neighbors(active_node.id)
#         #
#         # # 如果没有邻居，返回0损失
#         # if not neighbors:
#         #     return torch.tensor(0.0).to(device)
#
#         distillation_loss = torch.tensor(0.0).to(device)
#
#         for neighbor_id, attrs in neighbors:
#             neighbor_node = self.nodes[neighbor_id]
#
#             # 确保数据和模型在同一设备
#             frame = batch['frame'].to(device)
#             action_space_id = batch['action_space_id'].to(device)
#
#             # 使用邻居节点策略
#             with torch.no_grad():
#                 neighbor_logits = neighbor_node.policy_forward(frame, action_space_id)
#
#             # 使用当前节点策略
#             active_logits = active_node.policy_forward(frame, action_space_id)
#
#             if active_logits.shape != neighbor_logits.shape:
#                 self._logger(f"Shape mismatch: active={active_logits.shape}, neighbor={neighbor_logits.shape}")
#                 print(f"Shape mismatch: active={active_logits.shape}, neighbor={neighbor_logits.shape}")
#                 continue
#
#                 # 检查数值有效性
#             if torch.isnan(neighbor_logits).any():
#                 self._logger("NaN in neighbor logits")
#
#             # 添加KL散度损失
#             kl_loss = F.kl_div(
#                 F.log_softmax(active_logits, dim=-1),
#                 F.softmax(neighbor_logits.detach(), dim=-1),  # 避免梯度传播到邻居节点
#                 reduction='batchmean'
#             )
#             self._logger(f"KL loss with {neighbor_node.unique_id}: {kl_loss.item()}")
#
#             # 加权KL损失（基于相似度）
#             weight = attrs['weight']  # 边权值
#             distillation_loss += weight * kl_loss
#
#         # 平均损失并应用权重
#         return self.alpha * distillation_loss / len(neighbors)



# class NodeSimilarityCalculator:
#     def __init__(self, config):
#         self.config = config
#         # 相似度权重配置
#         self.weight_param = config.get("sim_weight_param", 0.4)  # 参数相似度权重
#         self.weight_feature = config.get("sim_weight_feature", 0.6)  # 特征相似度权重
#
#         # 特征处理配置
#         self.feature_sample_size = 200  # 从特征矩阵中采样的点数
#         self.device = self._get_default_device()
#
#     def _get_default_device(self):
#         return torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     def _compute_param_similarity(self, node_i, node_j):
#         """计算节点参数相似度"""
#         params_i = dict(node_i.impala_trainer.actor_model.named_parameters())
#         params_j = dict(node_j.impala_trainer.actor_model.named_parameters())
#
#         total_sim = 0.0
#         count = 0
#         for name, param_i in params_i.items():
#             if name in params_j:
#                 param_j = params_j[name].data
#                 # 参数对齐处理
#                 if param_i.shape != param_j.shape:
#                     min_size = min(param_i.numel(), param_j.numel())
#                     param_i = param_i.flatten()[:min_size]
#                     param_j = param_j.flatten()[:min_size]
#
#                 cos_sim = F.cosine_similarity(
#                     param_i.unsqueeze(0),
#                     param_j.unsqueeze(0),
#                     dim=1
#                 )
#                 total_sim += cos_sim.item()
#                 count += 1
#
#         return total_sim / count if count > 0 else 0.0
#
#     def _compute_feature_similarity(self, features_i, features_j):
#         """计算节点特征相似度(Wasserstein距离)"""
#         # 特征矩阵采样(避免计算过大矩阵)
#         if features_i.size(0) > self.feature_sample_size:
#             idx_i = torch.randperm(features_i.size(0))[:self.feature_sample_size]
#             features_i = features_i[idx_i]
#         if features_j.size(0) > self.feature_sample_size:
#             idx_j = torch.randperm(features_j.size(0))[:self.feature_sample_size]
#             features_j = features_j[idx_j]
#
#         # 计算Wasserstein距离
#         wasserstein_dist = self._wasserstein_distance(features_i, features_j)
#
#         # 转换为相似度(距离越小，相似度越高)
#         # 使用指数核函数确保值在0-1范围
#         feature_sim = torch.exp(-wasserstein_dist * 0.5)
#         return feature_sim.item()
#
#     def _wasserstein_distance(self, X, Y):
#         """高效Wasserstein距离计算(一阶近似)"""
#         return torch.norm(torch.mean(X, dim=0) - torch.mean(Y, dim=0), p=2)
#
#     def compute_combined_similarity(self, policy, node_i, node_j):
#         """
#         计算节点综合相似度：
#         - node_i, node_j: 节点对象
#         """
#         # 节点参数相似度
#         param_sim = self._compute_param_similarity(node_i, node_j)
#
#         # 节点特征相似度
#         if (node_i.unique_id in policy._node_features and
#                 node_j.unique_id in policy._node_features):
#             features_i = policy._node_features[node_i.unique_id].to(self.device)
#             features_j = policy._node_features[node_j.unique_id].to(self.device)
#             feature_sim = self._compute_feature_similarity(features_i, features_j)
#         else:
#             feature_sim = 0.0
#
#         # 加权组合
#         combined_sim = (
#                 self.weight_param * param_sim +
#                 self.weight_feature * feature_sim
#         )
#
#         # 归一化处理
#         total_weight = self.weight_param + self.weight_feature
#         return combined_sim / total_weight




